{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from os.path import join as pjoin\n",
    "\n",
    "from common.skeleton import Skeleton\n",
    "import numpy as np\n",
    "import os\n",
    "from common.quaternion import *\n",
    "from paramUtil import *\n",
    "\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def uniform_skeleton(positions, target_offset):\n",
    "    src_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu')\n",
    "    src_offset = src_skel.get_offsets_joints(torch.from_numpy(positions[0]))\n",
    "    src_offset = src_offset.numpy()\n",
    "    tgt_offset = target_offset.numpy()\n",
    "    # print(src_offset)\n",
    "    # print(tgt_offset)\n",
    "    '''Calculate Scale Ratio as the ratio of legs'''\n",
    "    src_leg_len = np.abs(src_offset[l_idx1]).max() + np.abs(src_offset[l_idx2]).max()\n",
    "    tgt_leg_len = np.abs(tgt_offset[l_idx1]).max() + np.abs(tgt_offset[l_idx2]).max()\n",
    "\n",
    "    scale_rt = tgt_leg_len / src_leg_len\n",
    "    # print(scale_rt)\n",
    "    src_root_pos = positions[:, 0]\n",
    "    tgt_root_pos = src_root_pos * scale_rt\n",
    "\n",
    "    '''Inverse Kinematics'''\n",
    "    quat_params = src_skel.inverse_kinematics_np(positions, face_joint_indx)\n",
    "    # print(quat_params.shape)\n",
    "\n",
    "    '''Forward Kinematics'''\n",
    "    src_skel.set_offset(target_offset)\n",
    "    new_joints = src_skel.forward_kinematics_np(quat_params, tgt_root_pos)\n",
    "    return new_joints\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_file(positions, feet_thre):\n",
    "    # (seq_len, joints_num, 3)\n",
    "    #     '''Down Sample'''\n",
    "    #     positions = positions[::ds_num]\n",
    "\n",
    "    '''Uniform Skeleton'''\n",
    "    positions = uniform_skeleton(positions, tgt_offsets)\n",
    "\n",
    "    '''Put on Floor'''\n",
    "    floor_height = positions.min(axis=0).min(axis=0)[1]\n",
    "    positions[:, :, 1] -= floor_height\n",
    "    #     print(floor_height)\n",
    "\n",
    "    #     plot_3d_motion(\"./positions_1.mp4\", kinematic_chain, positions, 'title', fps=20)\n",
    "\n",
    "    '''XZ at origin'''\n",
    "    root_pos_init = positions[0]\n",
    "    root_pose_init_xz = root_pos_init[0] * np.array([1, 0, 1])\n",
    "    positions = positions - root_pose_init_xz\n",
    "\n",
    "    # '''Move the first pose to origin '''\n",
    "    # root_pos_init = positions[0]\n",
    "    # positions = positions - root_pos_init[0]\n",
    "\n",
    "    '''All initially face Z+'''\n",
    "    r_hip, l_hip, sdr_r, sdr_l = face_joint_indx\n",
    "    across1 = root_pos_init[r_hip] - root_pos_init[l_hip]\n",
    "    across2 = root_pos_init[sdr_r] - root_pos_init[sdr_l]\n",
    "    across = across1 + across2\n",
    "    across = across / np.sqrt((across ** 2).sum(axis=-1))[..., np.newaxis]\n",
    "\n",
    "    # forward (3,), rotate around y-axis\n",
    "    forward_init = np.cross(np.array([[0, 1, 0]]), across, axis=-1)\n",
    "    # forward (3,)\n",
    "    forward_init = forward_init / np.sqrt((forward_init ** 2).sum(axis=-1))[..., np.newaxis]\n",
    "\n",
    "    #     print(forward_init)\n",
    "\n",
    "    target = np.array([[0, 0, 1]])\n",
    "    root_quat_init = qbetween_np(forward_init, target)\n",
    "    root_quat_init = np.ones(positions.shape[:-1] + (4,)) * root_quat_init\n",
    "\n",
    "    positions_b = positions.copy()\n",
    "\n",
    "    positions = qrot_np(root_quat_init, positions)\n",
    "\n",
    "    #     plot_3d_motion(\"./positions_2.mp4\", kinematic_chain, positions, 'title', fps=20)\n",
    "\n",
    "    '''New ground truth positions'''\n",
    "    global_positions = positions.copy()\n",
    "\n",
    "    # plt.plot(positions_b[:, 0, 0], positions_b[:, 0, 2], marker='*')\n",
    "    # plt.plot(positions[:, 0, 0], positions[:, 0, 2], marker='o', color='r')\n",
    "    # plt.xlabel('x')\n",
    "    # plt.ylabel('z')\n",
    "    # plt.axis('equal')\n",
    "    # plt.show()\n",
    "\n",
    "    \"\"\" Get Foot Contacts \"\"\"\n",
    "\n",
    "    def foot_detect(positions, thres):\n",
    "        velfactor, heightfactor = np.array([thres, thres]), np.array([3.0, 2.0])\n",
    "\n",
    "        feet_l_x = (positions[1:, fid_l, 0] - positions[:-1, fid_l, 0]) ** 2\n",
    "        feet_l_y = (positions[1:, fid_l, 1] - positions[:-1, fid_l, 1]) ** 2\n",
    "        feet_l_z = (positions[1:, fid_l, 2] - positions[:-1, fid_l, 2]) ** 2\n",
    "        #     feet_l_h = positions[:-1,fid_l,1]\n",
    "        #     feet_l = (((feet_l_x + feet_l_y + feet_l_z) < velfactor) & (feet_l_h < heightfactor)).astype(np.float)\n",
    "        feet_l = ((feet_l_x + feet_l_y + feet_l_z) < velfactor).astype(np.float32)\n",
    "\n",
    "        feet_r_x = (positions[1:, fid_r, 0] - positions[:-1, fid_r, 0]) ** 2\n",
    "        feet_r_y = (positions[1:, fid_r, 1] - positions[:-1, fid_r, 1]) ** 2\n",
    "        feet_r_z = (positions[1:, fid_r, 2] - positions[:-1, fid_r, 2]) ** 2\n",
    "        #     feet_r_h = positions[:-1,fid_r,1]\n",
    "        #     feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor) & (feet_r_h < heightfactor)).astype(np.float)\n",
    "        feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor)).astype(np.float32)\n",
    "        return feet_l, feet_r\n",
    "    #\n",
    "    feet_l, feet_r = foot_detect(positions, feet_thre)\n",
    "    # feet_l, feet_r = foot_detect(positions, 0.002)\n",
    "\n",
    "    '''Quaternion and Cartesian representation'''\n",
    "    r_rot = None\n",
    "\n",
    "    def get_rifke(positions):\n",
    "        '''Local pose'''\n",
    "        positions[..., 0] -= positions[:, 0:1, 0]\n",
    "        positions[..., 2] -= positions[:, 0:1, 2]\n",
    "        '''All pose face Z+'''\n",
    "        positions = qrot_np(np.repeat(r_rot[:, None], positions.shape[1], axis=1), positions)\n",
    "        return positions\n",
    "\n",
    "    def get_quaternion(positions):\n",
    "        skel = Skeleton(n_raw_offsets, kinematic_chain, \"cpu\")\n",
    "        # (seq_len, joints_num, 4)\n",
    "        quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=False)\n",
    "\n",
    "        '''Fix Quaternion Discontinuity'''\n",
    "        quat_params = qfix(quat_params)\n",
    "        # (seq_len, 4)\n",
    "        r_rot = quat_params[:, 0].copy()\n",
    "        #     print(r_rot[0])\n",
    "        '''Root Linear Velocity'''\n",
    "        # (seq_len - 1, 3)\n",
    "        velocity = (positions[1:, 0] - positions[:-1, 0]).copy()\n",
    "        #     print(r_rot.shape, velocity.shape)\n",
    "        velocity = qrot_np(r_rot[1:], velocity)\n",
    "        '''Root Angular Velocity'''\n",
    "        # (seq_len - 1, 4)\n",
    "        r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1]))\n",
    "        quat_params[1:, 0] = r_velocity\n",
    "        # (seq_len, joints_num, 4)\n",
    "        return quat_params, r_velocity, velocity, r_rot\n",
    "\n",
    "    def get_cont6d_params(positions):\n",
    "        skel = Skeleton(n_raw_offsets, kinematic_chain, \"cpu\")\n",
    "        # (seq_len, joints_num, 4)\n",
    "        quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=True)\n",
    "\n",
    "        '''Quaternion to continuous 6D'''\n",
    "        cont_6d_params = quaternion_to_cont6d_np(quat_params)\n",
    "        # (seq_len, 4)\n",
    "        r_rot = quat_params[:, 0].copy()\n",
    "        #     print(r_rot[0])\n",
    "        '''Root Linear Velocity'''\n",
    "        # (seq_len - 1, 3)\n",
    "        velocity = (positions[1:, 0] - positions[:-1, 0]).copy()\n",
    "        #     print(r_rot.shape, velocity.shape)\n",
    "        velocity = qrot_np(r_rot[1:], velocity)\n",
    "        '''Root Angular Velocity'''\n",
    "        # (seq_len - 1, 4)\n",
    "        r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1]))\n",
    "        # (seq_len, joints_num, 4)\n",
    "        return cont_6d_params, r_velocity, velocity, r_rot\n",
    "\n",
    "    cont_6d_params, r_velocity, velocity, r_rot = get_cont6d_params(positions)\n",
    "    positions = get_rifke(positions)\n",
    "\n",
    "    #     trejec = np.cumsum(np.concatenate([np.array([[0, 0, 0]]), velocity], axis=0), axis=0)\n",
    "    #     r_rotations, r_pos = recover_ric_glo_np(r_velocity, velocity[:, [0, 2]])\n",
    "\n",
    "    # plt.plot(positions_b[:, 0, 0], positions_b[:, 0, 2], marker='*')\n",
    "    # plt.plot(ground_positions[:, 0, 0], ground_positions[:, 0, 2], marker='o', color='r')\n",
    "    # plt.plot(trejec[:, 0], trejec[:, 2], marker='^', color='g')\n",
    "    # plt.plot(r_pos[:, 0], r_pos[:, 2], marker='s', color='y')\n",
    "    # plt.xlabel('x')\n",
    "    # plt.ylabel('z')\n",
    "    # plt.axis('equal')\n",
    "    # plt.show()\n",
    "\n",
    "    '''Root height'''\n",
    "    root_y = positions[:, 0, 1:2]\n",
    "\n",
    "    '''Root rotation and linear velocity'''\n",
    "    # (seq_len-1, 1) rotation velocity along y-axis\n",
    "    # (seq_len-1, 2) linear velovity on xz plane\n",
    "    r_velocity = np.arcsin(r_velocity[:, 2:3])\n",
    "    l_velocity = velocity[:, [0, 2]]\n",
    "    #     print(r_velocity.shape, l_velocity.shape, root_y.shape)\n",
    "    root_data = np.concatenate([r_velocity, l_velocity, root_y[:-1]], axis=-1)\n",
    "\n",
    "    '''Get Joint Rotation Representation'''\n",
    "    # (seq_len, (joints_num-1) *6) quaternion for skeleton joints\n",
    "    rot_data = cont_6d_params[:, 1:].reshape(len(cont_6d_params), -1)\n",
    "\n",
    "    '''Get Joint Rotation Invariant Position Represention'''\n",
    "    # (seq_len, (joints_num-1)*3) local joint position\n",
    "    ric_data = positions[:, 1:].reshape(len(positions), -1)\n",
    "\n",
    "    '''Get Joint Velocity Representation'''\n",
    "    # (seq_len-1, joints_num*3)\n",
    "    local_vel = qrot_np(np.repeat(r_rot[:-1, None], global_positions.shape[1], axis=1),\n",
    "                        global_positions[1:] - global_positions[:-1])\n",
    "    local_vel = local_vel.reshape(len(local_vel), -1)\n",
    "\n",
    "    data = root_data\n",
    "    data = np.concatenate([data, ric_data[:-1]], axis=-1)\n",
    "    data = np.concatenate([data, rot_data[:-1]], axis=-1)\n",
    "    #     print(data.shape, local_vel.shape)\n",
    "    data = np.concatenate([data, local_vel], axis=-1)\n",
    "    data = np.concatenate([data, feet_l, feet_r], axis=-1)\n",
    "\n",
    "    return data, global_positions, positions, l_velocity\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recover global angle and positions for rotation data\n",
    "# root_rot_velocity (B, seq_len, 1)\n",
    "# root_linear_velocity (B, seq_len, 2)\n",
    "# root_y (B, seq_len, 1)\n",
    "# ric_data (B, seq_len, (joint_num - 1)*3)\n",
    "# rot_data (B, seq_len, (joint_num - 1)*6)\n",
    "# local_velocity (B, seq_len, joint_num*3)\n",
    "# foot contact (B, seq_len, 4)\n",
    "def recover_root_rot_pos(data):\n",
    "    rot_vel = data[..., 0]\n",
    "    r_rot_ang = torch.zeros_like(rot_vel).to(data.device)\n",
    "    '''Get Y-axis rotation from rotation velocity'''\n",
    "    r_rot_ang[..., 1:] = rot_vel[..., :-1]\n",
    "    r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)\n",
    "\n",
    "    r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device)\n",
    "    r_rot_quat[..., 0] = torch.cos(r_rot_ang)\n",
    "    r_rot_quat[..., 2] = torch.sin(r_rot_ang)\n",
    "\n",
    "    r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device)\n",
    "    r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3]\n",
    "    '''Add Y-axis rotation to root position'''\n",
    "    r_pos = qrot(qinv(r_rot_quat), r_pos)\n",
    "\n",
    "    r_pos = torch.cumsum(r_pos, dim=-2)\n",
    "\n",
    "    r_pos[..., 1] = data[..., 3]\n",
    "    return r_rot_quat, r_pos\n",
    "\n",
    "\n",
    "def recover_from_rot(data, joints_num, skeleton):\n",
    "    r_rot_quat, r_pos = recover_root_rot_pos(data)\n",
    "\n",
    "    r_rot_cont6d = quaternion_to_cont6d(r_rot_quat)\n",
    "\n",
    "    start_indx = 1 + 2 + 1 + (joints_num - 1) * 3\n",
    "    end_indx = start_indx + (joints_num - 1) * 6\n",
    "    cont6d_params = data[..., start_indx:end_indx]\n",
    "    #     print(r_rot_cont6d.shape, cont6d_params.shape, r_pos.shape)\n",
    "    cont6d_params = torch.cat([r_rot_cont6d, cont6d_params], dim=-1)\n",
    "    cont6d_params = cont6d_params.view(-1, joints_num, 6)\n",
    "\n",
    "    positions = skeleton.forward_kinematics_cont6d(cont6d_params, r_pos)\n",
    "\n",
    "    return positions\n",
    "\n",
    "\n",
    "def recover_from_ric(data, joints_num):\n",
    "    r_rot_quat, r_pos = recover_root_rot_pos(data)\n",
    "    positions = data[..., 4:(joints_num - 1) * 3 + 4]\n",
    "    positions = positions.view(positions.shape[:-1] + (-1, 3))\n",
    "\n",
    "    '''Add Y-axis rotation to local joints'''\n",
    "    positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions)\n",
    "\n",
    "    '''Add root XZ to joints'''\n",
    "    positions[..., 0] += r_pos[..., 0:1]\n",
    "    positions[..., 2] += r_pos[..., 2:3]\n",
    "\n",
    "    '''Concate root and joints'''\n",
    "    positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2)\n",
    "\n",
    "    return positions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The given data is used to double check if you are on the right track.\n",
    "reference1 = np.load('./HumanML3D/new_joints/012314.npy')\n",
    "reference2 = np.load('./HumanML3D/new_joint_vecs/012314.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|▏         | 417/29232 [00:26<24:26, 19.65it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "011059.npy\n",
      "cannot reshape array of size 0 into shape (0,newaxis)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|▍         | 1400/29232 [01:16<21:22, 21.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "M009707.npy\n",
      "cannot reshape array of size 0 into shape (0,newaxis)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 59%|█████▉    | 17203/29232 [14:08<08:35, 23.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "009707.npy\n",
      "cannot reshape array of size 0 into shape (0,newaxis)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 88%|████████▊ | 25842/29232 [21:14<02:18, 24.48it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "M011059.npy\n",
      "cannot reshape array of size 0 into shape (0,newaxis)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 29232/29232 [24:06<00:00, 20.21it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total clips: 29232, Frames: 4117392, Duration: 3431.160000m\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "'''\n",
    "For HumanML3D Dataset\n",
    "'''\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    example_id = \"000021\"\n",
    "    # Lower legs\n",
    "    l_idx1, l_idx2 = 5, 8\n",
    "    # Right/Left foot\n",
    "    fid_r, fid_l = [8, 11], [7, 10]\n",
    "    # Face direction, r_hip, l_hip, sdr_r, sdr_l\n",
    "    face_joint_indx = [2, 1, 17, 16]\n",
    "    # l_hip, r_hip\n",
    "    r_hip, l_hip = 2, 1\n",
    "    joints_num = 22\n",
    "    # ds_num = 8\n",
    "    data_dir = './joints/'\n",
    "    save_dir1 = './HumanML3D/new_joints/'\n",
    "    save_dir2 = './HumanML3D/new_joint_vecs/'\n",
    "    \n",
    "    os.makedirs(save_dir1, exist_ok=True)\n",
    "    os.makedirs(save_dir2, exist_ok=True)\n",
    "\n",
    "    n_raw_offsets = torch.from_numpy(t2m_raw_offsets)\n",
    "    kinematic_chain = t2m_kinematic_chain\n",
    "\n",
    "    # Get offsets of target skeleton\n",
    "    example_data = np.load(os.path.join(data_dir, example_id + '.npy'))\n",
    "    example_data = example_data.reshape(len(example_data), -1, 3)\n",
    "    example_data = torch.from_numpy(example_data)\n",
    "    tgt_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu')\n",
    "    # (joints_num, 3)\n",
    "    tgt_offsets = tgt_skel.get_offsets_joints(example_data[0])\n",
    "    # print(tgt_offsets)\n",
    "\n",
    "    source_list = os.listdir(data_dir)\n",
    "    frame_num = 0\n",
    "    for source_file in tqdm(source_list):\n",
    "        source_data = np.load(os.path.join(data_dir, source_file))[:, :joints_num]\n",
    "        try:\n",
    "            data, ground_positions, positions, l_velocity = process_file(source_data, 0.002)\n",
    "            rec_ric_data = recover_from_ric(torch.from_numpy(data).unsqueeze(0).float(), joints_num)\n",
    "            np.save(pjoin(save_dir1, source_file), rec_ric_data.squeeze().numpy())\n",
    "            np.save(pjoin(save_dir2, source_file), data)\n",
    "            frame_num += data.shape[0]\n",
    "        except Exception as e:\n",
    "            print(source_file)\n",
    "            print(e)\n",
    "#         print(source_file)\n",
    "#         break\n",
    "\n",
    "    print('Total clips: %d, Frames: %d, Duration: %fm' %\n",
    "          (len(source_list), frame_num, frame_num / 20 / 60))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check if your data is correct. If it's aligned with the given reference, then it is right"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "reference1_1 = np.load('./HumanML3D/new_joints/012314.npy')\n",
    "reference2_1 = np.load('./HumanML3D/new_joint_vecs/012314.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "abs(reference1 - reference1_1).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "abs(reference2 - reference2_1).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:torch_render]",
   "language": "python",
   "name": "conda-env-torch_render-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
